{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 基于FastText的意图分类\n",
    "\n",
    "在这里我们训练一个FastText意图识别模型，并把训练好的模型存放在模型文件里。 意图识别实际上是文本分类任务，需要标注的数据：每一个句子需要对应的标签如闲聊型的，任务型的。但在这个项目中，我们并没有任何标注的数据，而且并不需要搭建闲聊机器人。所以这里搭建的FastText模型只是一个dummy模型，没有什么任何的作用。这个模块只是为了项目的完整性，也让大家明白FastText如何去使用，仅此而已。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import fasttext\n",
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "import numpy as np\n",
    "import pickle"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 读取数据：导入在preprocessor.ipynb中生成的data/question_answer_pares.pkl文件，并将其保存在变量QApares中\n",
    "with open('data/question_answer_pares.pkl','rb') as f:\n",
    "    QApares = pickle.load(f)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 这一文档的目的是为了用fasttext来进行意图识别，将问题分为任务型和闲聊型两类\n",
    "# 训练数据集是任务型还是闲聊型本应由人工打上标签，但这并不是我们的重点。我们的重点是教会大家如何用fasttext来进行意图识别\n",
    "# 所以这里我们为数据集随机打上0或1的标签，而不去管实际情况如何\n",
    "\n",
    "#fasttext的输入格式为：单词1 单词2 单词3 ... 单词n __label__标签号\n",
    "#我们将问题数据集整理为fasttext需要的输入格式并为其随机打上标签并将结果保存在data/fasttext/fasttext_train.txt和data/fasttext/fasttext_test.txt中\n",
    "with open('data/fasttext/fasttext_train.txt','w') as f:\n",
    "    for content in QApares[:int(0.7*len(QApares))].dropna().question_after_preprocessing:\n",
    "        f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))\n",
    "with open('data/fasttext/fasttext_test.txt','w') as f:\n",
    "    for content in QApares[int(0.7*len(QApares)):].dropna().question_after_preprocessing:\n",
    "        f.write('%s __label__%d\\n' % (' '.join(content), np.random.randint(0,2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用fasttext进行意图识别，并将模型保存在classifier中\n",
    "classifier = fasttext.train_supervised('data/fasttext/fasttext_train.txt',     #训练数据文件路径\n",
    "                                       label=\"__label__\",      #类别前缀\n",
    "                                       dim=100,       #向量维度\n",
    "                                       epoch=5,       #训练轮次\n",
    "                                       lr=0.1,        #学习率\n",
    "                                       wordNgrams=2,      #n-gram个数\n",
    "                                       loss='softmax',    #损失函数类型\n",
    "                                       thread=5,          #线程个数, 每个线程处理输入数据的一段, 0号线程负责loss输出\n",
    "                                       verbose=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(('__label__1',), array([0.50656599]))"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#使用训练好的fasttext模型进行预测\n",
    "classifier.predict('今天 月亮 真 圆 啊')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(30000, 0.5006666666666667, 0.5006666666666667)"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#使用训练好的fasttext模型对测试集文件进行评估\n",
    "classifier.test('data/fasttext/fasttext_test.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "#保存模型\n",
    "classifier.save_model('model/fasttext.ftz')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(('__label__1',), array([0.5010941]))"
      ]
     },
     "execution_count": 31,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 测试\n",
    "classifier.predict('商品 有优惠吗')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(('__label__0',), array([0.50158423]))"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 测试\n",
    "classifier.predict('发什么快递')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
